import pandas as pd
import os


def stratified_sample(file_path, sample_ratio=0.3, random_state=42):
    """Draw a reproducible, label-stratified random sample from a TSV file.

    The file is read as tab-separated text and must contain a ``label``
    column. Every label group contributes ``sample_ratio`` of its rows
    (rounded by pandas), and the combined sample is then shuffled.

    Args:
        file_path: Path to the tab-separated input file.
        sample_ratio: Fraction of rows to keep from each label group.
        random_state: Seed used both for the per-group sampling and for the
            final shuffle, so repeated calls return identical results.

    Returns:
        A ``pandas.DataFrame`` with the same columns as the input, holding
        the sampled rows in shuffled order. Original row indices are kept.

    Raises:
        KeyError: If the file has no ``label`` column.
    """
    # Read the TSV file.
    df = pd.read_csv(file_path, sep='\t')

    # Stratify by label. Sampling each group frame explicitly (instead of
    # groupby.apply) keeps the 'label' column in the result regardless of how
    # the installed pandas version treats grouping columns inside apply.
    pieces = [
        group.sample(frac=sample_ratio, random_state=random_state)
        for _, group in df.groupby('label')
    ]

    # No groups at all (empty file or only missing labels): return an empty
    # frame that still carries the input columns.
    sampled = pd.concat(pieces) if pieces else df.iloc[0:0]
    if sampled.empty:
        return sampled

    # Shuffle so rows are not ordered by label; seeded for reproducibility.
    return sampled.sample(frac=1, random_state=random_state)


# Directory holding the source splits and directory receiving the sampled ones.
base_dir = "thucnews_tiny_tsv"
output_dir = "thucnews_t2_tsv"  # uses the same split file names as base_dir

# to_csv does not create missing parent directories, so make sure the target
# directory exists before writing any split.
os.makedirs(output_dir, exist_ok=True)

# Sample every split with the default ratio and write it under output_dir.
for split in ['dev', 'test', 'train']:
    input_file = os.path.join(base_dir, f"{split}.tsv")
    output_file = os.path.join(output_dir, f"{split}.tsv")

    sampled_df = stratified_sample(input_file)
    # index=False keeps the original row indices out of the output file.
    sampled_df.to_csv(output_file, sep='\t', index=False)
    print(f"Created {output_file} with {len(sampled_df)} samples")




